{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Minimising Lennard-Jones Atom Cluster Potentials with Evolution\n",
    "\n",
    "In this example, we'll demonstrate the application of evolution to a challenging, multi-modal black-box optimisation problem, the Lennard-Jones atom cluster potential minimisation task [1]. This problem has previously been studied with SNES [2]. The task is defined for $N$ atoms, each with a position $P_i = (x_i, y_i, z_i)$ in 3D space. The distance from atom $i$ to atom $j$ is,\n",
    "\n",
    "$$r_{i,j} = |P_i - P_j|$$\n",
    "\n",
    "The total atom cluster potential is,\n",
    "\n",
    "$$E = 4 \\epsilon \\sum_{i = 1}^N \\sum_{j = 1}^{i - 1} \\left(\\frac{\\sigma}{r_{i,j}}\\right)^{12} - \\left(\\frac{\\sigma}{r_{i,j}}\\right)^{6}$$\n",
    "\n",
    "where here we will use the reduced units [1], $\\epsilon = \\sigma = 1$. This function can be implemented in a vectorised fashion in PyTorch. Let's assume that the $N$ positions are described be a vector $x \\in \\mathbb{R}^{3N}$. Then we can define a vectorised computation of the distances:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def pairwise_distances(positions: torch.Tensor) -> torch.Tensor:\n",
    "    # Assume positions has shape [B, 3N] where B is the batch size and N is the number of atoms\n",
    "    # Reshaping to get individual atoms' positions of shape [B, N, 3]\n",
    "    positions = positions.view(positions.shape[0], -1, 3)\n",
    "    # Compute the pairwise differences\n",
    "    # Subtracting [B, 1, N, 3] from [B, N, 1, 3] gives [B, N, N, 3] \n",
    "    deltas = (positions.unsqueeze(2) - positions.unsqueeze(1))\n",
    "    # Norm the differences gives [B, N, N]\n",
    "    distances = torch.norm(deltas, dim = -1)\n",
    "    return distances"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Which gives a straightforward definition of the cluster potential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cluster_potential(positions: torch.Tensor) -> torch.Tensor:\n",
    "    # Compute the pairwise distances of atoms\n",
    "    distances = pairwise_distances(positions)\n",
    "    \n",
    "    # Compute the pairwise cost (1 / dist)^12 - (1 / dist)^ 6\n",
    "    pairwise_cost = (1 / distances).pow(12) - (1 / distances).pow(6.)\n",
    "    \n",
    "    # Get the upper triangle matrix (ignoring the diagonal)\n",
    "    ut_pairwise_cost = torch.triu(pairwise_cost, diagonal = 1)\n",
    "    \n",
    "    # 4 * Summutation of the upper triangle of pairwise costs gives potential\n",
    "    potential = 4 * ut_pairwise_cost.sum(dim = (1, 2))\n",
    "    return potential"
   ]
  },
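  {
   "cell_type": "markdown",
   "id": "3a",
   "metadata": {},
   "source": [
    "As a quick sanity check on this definition (a worked example added for illustration, not taken from [1] or [2]), consider $N = 2$ in reduced units. The potential then depends only on the separation $r = r_{2,1}$:\n",
    "\n",
    "$$E(r) = 4\\left(r^{-12} - r^{-6}\\right), \\qquad \\frac{dE}{dr} = 4\\left(-12 r^{-13} + 6 r^{-7}\\right) = 0 \\quad\\Rightarrow\\quad r^{6} = 2$$\n",
    "\n",
    "so the minimum lies at $r^* = 2^{1/6} \\approx 1.122$ with $E(r^*) = 4\\left(\\tfrac{1}{4} - \\tfrac{1}{2}\\right) = -1$. For $N = 3$, an equilateral triangle with side $r^*$ places all three pairs at this minimum simultaneously, giving $E = -3$. Passing such positions to `cluster_potential` should therefore return approximately $-3$."
   ]
  },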
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "### Obtaining Reference Solutions\n",
    "To measure our performance, its helpful to refer to a reference point. For this, we will use a [publicly available database of known global optima](http://doye.chem.ox.ac.uk/jon/structures/LJ/tables.150.html). First, let's download the `tar` of optima and extract it. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import tarfile\n",
    "\n",
    "# Url of tar containing known global minima\n",
    "url = 'http://doye.chem.ox.ac.uk/jon/structures/LJ/LJ.tar'\n",
    "# Where to save the tar -- modify as desired\n",
    "target_path = 'LJ_data.tar'\n",
    "\n",
    "# Download\n",
    "response = requests.get(url, stream=True)\n",
    "if response.status_code == 200:\n",
    "    with open(target_path, 'wb') as f:\n",
    "        f.write(response.raw.read())\n",
    "        \n",
    "# Open file\n",
    "file = tarfile.open(target_path)\n",
    "\n",
    "#  By default save the data to the 'LJ_data' folder in the local directory\n",
    "data_path = f'./{target_path.replace(\".tar\", \"\")}'\n",
    "file.extractall(data_path) \n",
    "  \n",
    "file.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "Now we can plot the computed potential of each solution obtained from the data base. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas\n",
    "\n",
    "# Lists to track atom counts and potentials\n",
    "atom_counts = []\n",
    "global_potentials = []\n",
    "\n",
    "# Iterate from 3 to 67 atoms -- the number visited in the paper\n",
    "for n_atoms in range(3, 68):\n",
    "    # File path is simply the nuimber of atoms\n",
    "    file_path = f'{data_path}/{n_atoms}'\n",
    "    # Get the positions as a dataframe\n",
    "    dataframe = pandas.read_csv(file_path, header=None, delim_whitespace=True)\n",
    "    # Make a positions tensor -- note that we add an initial dimension as the objective function is vectorised\n",
    "    positions = torch.Tensor(dataframe.to_numpy()).unsqueeze(0)\n",
    "    # Get the potential\n",
    "    potential = cluster_potential(positions)\n",
    "    \n",
    "    # Update lists of atom counts and potentials\n",
    "    atom_counts.append(n_atoms)\n",
    "    global_potentials.append(potential.item())\n",
    "    \n",
    "# Simple plot\n",
    "plt.plot(atom_counts, global_potentials)\n",
    "plt.xlabel('Number of Atoms $N$')\n",
    "plt.ylabel('Cluster potential $E$')\n",
    "plt.show()\n",
    "\n",
    "# Sanity check on the last one\n",
    "print(f'Potential of N={n_atoms} is computed as {potential.item()} vs. published value -347.252007')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "### Benchmarking SNES \n",
    "\n",
    "Let's consider benchmarking SNES on this problem. It is worth noting that evaluation can be done on the GPU due to how we've designed the function to minimise, so feel free to use `device = 'cuda:0'`, as we have done, if you have a cuda-capable device available. Otherwise, you should still see some speedup with `device = 'cpu'` due to PyTorch's efficient implementation of batched operators even on the CPU. In any case, we'll use a population size substantially higher than the default. We'll run the code below for only the first 17 atom clusters, but you can freely push this value higher for your own interest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from evotorch import Problem\n",
    "from evotorch.algorithms import SNES\n",
    "\n",
    "snes_potentials = []\n",
    "\n",
    "for n_atoms in range(3, 20):\n",
    "    \n",
    "    print(f'Solving case N={n_atoms} with SNES')\n",
    "    \n",
    "    # Set up the problem in vectorised mode\n",
    "    problem  = Problem(\n",
    "        'min',\n",
    "        cluster_potential,\n",
    "        vectorized = True,\n",
    "        device = 'cuda:0' if torch.cuda.is_available() else 'cpu',\n",
    "        dtype = torch.float64, # Higher precision could be helpful \n",
    "        solution_length = 3 * n_atoms,\n",
    "        initial_bounds = (-1e-12, 1e-12), # Taken directly from [2]\n",
    "        store_solution_stats = True,  # Make sure the problem tracks the best discovered solution, even on GPU\n",
    "    )\n",
    "    \n",
    "    searcher = SNES(\n",
    "        problem,\n",
    "        stdev_init = 0.01,  # Taken directly from [2]\n",
    "        popsize = 10 * n_atoms,  # Significantly higher than [2]\n",
    "        center_learning_rate = 0.5,  # Halving value from [2] slows learning\n",
    "        stdev_learning_rate = 0.5,  # Halving value from [2] slows learning\n",
    "        scale_learning_rate = True,  # Boolean flag means modifying the above learning rates rescales the default\n",
    "    )\n",
    "    searcher.run(1000 * problem.solution_length)   # 2x value used in [2], adjusted for half learning rates\n",
    "    \n",
    "    # Best solution found\n",
    "    best_potential = problem.status['best_eval']\n",
    "    # Center is also a good estimate\n",
    "    center_potential = cluster_potential(searcher.status['center'].cpu().unsqueeze(0))[0].item()\n",
    "    if center_potential < best_potential:\n",
    "        best_potential = center_potential\n",
    "        \n",
    "    print(f'Found potential {best_potential}')\n",
    "    \n",
    "    snes_potentials.append(best_potential)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "Finally let's take a look at how we did. We should see that for most cases, particularly for smaller atom clusters, SNES was either exactly recovering or was finding a solution very close to the known global optima of atom positions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple plot\n",
    "plt.plot(atom_counts[0:len(snes_potentials)], global_potentials[0:len(snes_potentials)], label = 'Known Optima')\n",
    "plt.plot(atom_counts[0:len(snes_potentials)], snes_potentials[0:len(snes_potentials)], label = 'SNES-discovered Solutions')\n",
    "plt.legend()\n",
    "plt.xlabel('Number of Atoms $N$')\n",
    "plt.ylabel('Atom Cluster Potential $E$')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "#### References\n",
    "[1] Wales and Doye. [\"Global optimization by basin-hopping and the lowest energy structures of Lennard-Jones clusters containing up to 110 atoms.\"](https://pubs.acs.org/doi/abs/10.1021/jp970984n) The Journal of Physical Chemistry A 101.28 (1997)\n",
    "\n",
    "[2] Schaul et. al. [\"High dimensions and heavy tails for natural evolution strategies.\"](https://dl.acm.org/doi/abs/10.1145/2001576.2001692) Proceedings of the 13th annual conference on Genetic and evolutionary computation. 2011."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
